{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MXNet Gluon Multi-GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "MULTI_GPU = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda/envs/py35/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "import multiprocessing\n",
    "import logging\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import mxnet as mx\n",
    "from mxnet.io import DataDesc\n",
    "from mxnet import nd, gluon, autograd\n",
    "from mxnet.gluon.data import RecordFileDataset, ArrayDataset, Dataset\n",
    "from mxnet.gluon.data.vision import transforms\n",
    "from mxnet.gluon.data.vision.datasets import ImageFolderDataset\n",
    "from mxnet.gluon.data.dataloader import DataLoader\n",
    "from mxnet.gluon.model_zoo import vision as models\n",
    "from mxnet import recordio\n",
    "\n",
    "from sklearn.metrics.ranking import roc_auc_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from PIL import Image\n",
    "from common.utils import *\n",
    "from common.params_dense import *\n",
    "import math\n",
    "from time import time\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OS:  linux\n",
      "Python:  3.5.4 |Anaconda custom (64-bit)| (default, Nov 20 2017, 18:44:38) \n",
      "[GCC 7.2.0]\n",
      "Numpy:  1.14.1\n",
      "MXNet:  1.3.0\n",
      "GPU:  ['Tesla V100-PCIE-16GB', 'Tesla V100-PCIE-16GB', 'Tesla V100-PCIE-16GB', 'Tesla V100-PCIE-16GB']\n",
      "CUDA Version 9.0.176\n",
      "CuDNN Version  7.0.5\n"
     ]
    }
   ],
   "source": [
    "print(\"OS: \", sys.platform)\n",
    "print(\"Python: \", sys.version)\n",
    "print(\"Numpy: \", np.__version__)\n",
    "print(\"MXNet: \", mx.__version__)\n",
    "print(\"GPU: \", get_gpu_name())\n",
    "print(get_cuda_version())\n",
    "print(\"CuDNN Version \", get_cudnn_version())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPUs:  24\n",
      "GPUs:  4\n"
     ]
    }
   ],
   "source": [
    "# User-set\n",
    "# Note if NUM_GPUS > 1 then MULTI_GPU = True and ALL GPUs will be used\n",
    "# Set below to affect batch-size\n",
    "# E.g. 1 GPU = 64, 2 GPUs =64*2, 4 GPUs = 64*4\n",
    "# Note that the effective learning-rate will be decreased this way\n",
    "CPU_COUNT = multiprocessing.cpu_count() \n",
    "GPU_COUNT = len(get_gpu_name())\n",
    "if not MULTI_GPU:\n",
    "    GPU_COUNT = 1\n",
    "print(\"CPUs: \", CPU_COUNT)\n",
    "print(\"GPUs: \", GPU_COUNT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Manually scale to multi-gpu\n",
    "if MULTI_GPU:\n",
    "    LR *= GPU_COUNT\n",
    "    BATCHSIZE *= (GPU_COUNT)\n",
    "    BATCHSIZE = BATCHSIZE//GPU_COUNT*GPU_COUNT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Download"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model-params\n",
    "# Paths\n",
    "CSV_DEST = \"chestxray\"\n",
    "IMAGE_FOLDER = os.path.join(CSV_DEST, \"images\")\n",
    "LABEL_FILE = os.path.join(CSV_DEST, \"Data_Entry_2017.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Please make sure to download\n",
      "https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-linux#download-and-install-azcopy\n",
      "Data already exists\n",
      "CPU times: user 587 ms, sys: 216 ms, total: 803 ms\n",
      "Wall time: 802 ms\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda/envs/py35/lib/python3.5/site-packages/_pytest/fixtures.py:844: DeprecationWarning: The `convert` argument is deprecated in favor of `converter`.  It will be removed after 2019/01.\n",
      "  params = attr.ib(convert=attr.converters.optional(tuple))\n",
      "/anaconda/envs/py35/lib/python3.5/site-packages/_pytest/fixtures.py:846: DeprecationWarning: The `convert` argument is deprecated in favor of `converter`.  It will be removed after 2019/01.\n",
      "  ids = attr.ib(default=None, convert=_ensure_immutable_ids)\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Download data\n",
    "print(\"Please make sure to download\")\n",
    "print(\"https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-linux#download-and-install-azcopy\")\n",
    "download_data_chextxray(CSV_DEST)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data prep\n",
    "https://github.com/apache/incubator-mxnet/issues/1480\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train:21563 valid:3080 test:6162\n"
     ]
    }
   ],
   "source": [
    "train_set, valid_set, test_set = get_train_valid_test_split(TOT_PATIENT_NUMBER)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Loading"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating the datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class XrayData(Dataset):\n",
    "    def __init__(self, img_dir, lbl_file, patient_ids, transform=None):\n",
    "        \n",
    "        self.img_locs, self.labels = get_imgloc_labels(img_dir, lbl_file, patient_ids)\n",
    "        self.transform = transform\n",
    "        print(\"Loaded {} labels and {} images\".format(len(self.labels), len(self.img_locs)))\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        im_file = self.img_locs[idx]\n",
    "        im_rgb = Image.open(im_file)\n",
    "        label = self.labels[idx]\n",
    "        im_rgb = mx.nd.array(im_rgb)\n",
    "        if self.transform is not None:\n",
    "            im_rgb = self.transform(im_rgb)\n",
    "\n",
    "        return im_rgb, mx.nd.array(label)\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.img_locs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def no_augmentation_dataset(img_dir, lbl_file, patient_ids, normalize):\n",
    "    dataset = XrayData(img_dir, lbl_file, patient_ids,\n",
    "                       transform=transforms.Compose([\n",
    "                           transforms.Resize(WIDTH),\n",
    "                           transforms.ToTensor(),  \n",
    "                           transforms.Normalize(IMAGENET_RGB_MEAN, IMAGENET_RGB_SD)]))\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded 87306 labels and 87306 images\n"
     ]
    }
   ],
   "source": [
    "# Dataset for training\n",
    "train_dataset = XrayData(img_dir=IMAGE_FOLDER,\n",
    "                         lbl_file=LABEL_FILE,\n",
    "                         patient_ids=train_set,\n",
    "                         transform=transforms.Compose([\n",
    "                             transforms.RandomResizedCrop(size=WIDTH),\n",
    "                             transforms.RandomFlipLeftRight(),\n",
    "                             transforms.ToTensor(),\n",
    "                             transforms.Normalize(IMAGENET_RGB_MEAN, IMAGENET_RGB_SD)]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded 7616 labels and 7616 images\n",
      "Loaded 17198 labels and 17198 images\n"
     ]
    }
   ],
   "source": [
    "valid_dataset = no_augmentation_dataset(IMAGE_FOLDER, LABEL_FILE, valid_set, transforms.Normalize(IMAGENET_RGB_MEAN, IMAGENET_RGB_SD))\n",
    "test_dataset = no_augmentation_dataset(IMAGE_FOLDER, LABEL_FILE, test_set, transforms.Normalize(IMAGENET_RGB_MEAN, IMAGENET_RGB_SD))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DataLoaders\n",
    "train_loader = DataLoader(dataset=train_dataset, batch_size=BATCHSIZE,\n",
    "                          shuffle=True, num_workers=6, last_batch='discard')\n",
    "valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCHSIZE,\n",
    "                          shuffle=False, num_workers=6, last_batch='discard')\n",
    "test_loader = DataLoader(dataset=test_dataset, batch_size=BATCHSIZE,\n",
    "                         shuffle=False, num_workers=6, last_batch='discard')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating the network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading the pretrained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "ctx = [mx.gpu(i) for i in range(GPU_COUNT)]   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = mx.gluon.model_zoo.vision.densenet121(pretrained=True, ctx=ctx)\n",
    "with net.name_scope():\n",
    "    net.output = mx.gluon.nn.Dense(CLASSES)\n",
    "net.output.initialize(ctx=ctx)\n",
    "net.hybridize()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': LR})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loss "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "binary_cross_entropy = gluon.loss.SigmoidBinaryCrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "sig = gluon.nn.Activation('sigmoid')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_accuracy(data_iterator, net):\n",
    "    acc = 0\n",
    "    for i, (data, label) in enumerate(data_iterator):\n",
    "        data_split = gluon.utils.split_and_load(data, ctx)\n",
    "        label_split = gluon.utils.split_and_load(label, ctx)\n",
    "        outputs = [(sig(net(X)),Y) for X, Y in zip(data_split, label_split)]\n",
    "        for output, label in outputs:\n",
    "            acc += float((label.asnumpy() == np.round(output.asnumpy())).sum()) / CLASSES / output.shape[0]\n",
    "    data_split = gluon.utils.split_and_load(data, [mx.cpu()])\n",
    "    label_split = gluon.utils.split_and_load(label, [mx.cpu()])\n",
    "    return acc/i/len(ctx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_batch = 5 # Blocking call every 5 batches\n",
    "n_print = 100 # Print every 100 batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_epoch(net, dataloader, trainer, loss_fn, ctx, n_batch=7, n_print=100):\n",
    "    losses_acc = [mx.nd.zeros((1), ctx=c) for c in ctx]\n",
    "    print_loss = 0\n",
    "    for i, (data, label) in enumerate(dataloader):        \n",
    "        data_split = gluon.utils.split_and_load(data, ctx)\n",
    "        label_split = gluon.utils.split_and_load(label, ctx)  \n",
    "        \n",
    "        if i > 0:\n",
    "            for j, l in enumerate(losses):\n",
    "                # Accumulate losses asynchronously on each GPU\n",
    "                losses_acc[j] += l.mean()\n",
    "            if i%n_batch == 0:\n",
    "                # Blocking call\n",
    "                print_loss = 0\n",
    "                for l in losses_acc:\n",
    "                    print_loss += l.asscalar()\n",
    "                l = l / (i+1)/len(ctx)\n",
    "            if i%n_print == 0:\n",
    "                print('Batch {0}: Loss: {1:.4f}'.format(i, print_loss))            \n",
    "            \n",
    "        with autograd.record():\n",
    "            losses = [loss_fn(net(X), Y) for X, Y in zip(data_split, label_split)]\n",
    "        for l in losses:\n",
    "            l.backward()\n",
    "        trainer.step(data.shape[0]) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch 100: Loss: 82.8983\n",
      "Batch 200: Loss: 148.3790\n",
      "Batch 300: Loss: 212.7993\n",
      "Epoch 0, 0.985651 test_accuracy after 122.56 seconds\n",
      "Batch 100: Loss: 63.2765\n",
      "Batch 200: Loss: 125.7693\n",
      "Batch 300: Loss: 187.9143\n",
      "Epoch 1, 0.986099 test_accuracy after 114.79 seconds\n",
      "Batch 100: Loss: 61.0549\n",
      "Batch 200: Loss: 122.8533\n",
      "Batch 300: Loss: 184.8669\n",
      "Epoch 2, 0.985880 test_accuracy after 116.16 seconds\n",
      "Batch 100: Loss: 60.4399\n",
      "Batch 200: Loss: 121.4686\n",
      "Batch 300: Loss: 182.2017\n",
      "Epoch 3, 0.985979 test_accuracy after 114.90 seconds\n",
      "Batch 100: Loss: 60.2344\n",
      "Batch 200: Loss: 120.4585\n",
      "Batch 300: Loss: 180.7013\n",
      "Epoch 4, 0.985700 test_accuracy after 115.12 seconds\n",
      "CPU times: user 31min 24s, sys: 11min 1s, total: 42min 26s\n",
      "Wall time: 9min 43s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# 1 GPU - Main training loop: 28min 40s\n",
    "# 4 GPU - Main training loop: 9min 43s\n",
    "for e in range(EPOCHS):\n",
    "    tick = time()\n",
    "    train_epoch(net, train_loader, trainer, binary_cross_entropy, ctx)\n",
    "    test_accuracy = evaluate_accuracy(valid_loader, net)\n",
    "    print('Epoch {0}, {1:.6f} test_accuracy after {2:.2f} seconds'.format(e, test_accuracy, time()-tick))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 20.1 s, sys: 7.22 s, total: 27.3 s\n",
      "Wall time: 13.7 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "predictions = np.zeros((0, CLASSES))\n",
    "labels = np.zeros((0, CLASSES))\n",
    "for (data, label) in (test_loader):        \n",
    "    data_split = gluon.utils.split_and_load(data, ctx)\n",
    "    label_split = gluon.utils.split_and_load(label, ctx)  \n",
    "    outputs = [sig(net(X)) for X in data_split]\n",
    "    predictions = np.concatenate([predictions, np.concatenate([output.asnumpy() for output in outputs])])\n",
    "    labels = np.concatenate([labels, np.concatenate([label.asnumpy() for label in label_split])])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Full AUC [0.8125718127685243, 0.8551507280005255, 0.8071080700770485, 0.8908148250913325, 0.8860359865394066, 0.9254589157853671, 0.7143975677357627, 0.8279628814488093, 0.6307513183430535, 0.8473062753975474, 0.7456833776349139, 0.8078839618223166, 0.7708621556491738, 0.8808694022714751]\n",
      "Validation AUC: 0.8145\n"
     ]
    }
   ],
   "source": [
    "# 1 GPU AUC: 0.8235\n",
    "# 4 GPU AUC: 0.8145\n",
    "print(\"Validation AUC: {0:.4f}\".format(compute_roc_auc(labels, predictions, CLASSES)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Synthetic Data (Pure Training)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "87296\n"
     ]
    }
   ],
   "source": [
    "# Test on fake-data -> no IO lag\n",
    "batch_in_epoch = len(train_dataset.labels)//BATCHSIZE\n",
    "tot_num = batch_in_epoch * BATCHSIZE\n",
    "print(tot_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "fake_X = mx.nd.ones((tot_num, 3, 224, 224), dtype=np.float32)\n",
    "fake_y = mx.nd.ones((tot_num, CLASSES), dtype=np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset_synth = ArrayDataset(fake_X, fake_y)\n",
    "train_dataloader_synth = DataLoader(train_dataset_synth, BATCHSIZE, shuffle=False, num_workers=0, last_batch='discard')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch 100: Loss: 61.9009\n",
      "Batch 200: Loss: 61.9074\n",
      "Batch 300: Loss: 61.9134\n",
      "Epoch 0, 102.29 seconds\n",
      "Batch 100: Loss: 0.0054\n",
      "Batch 200: Loss: 0.0103\n",
      "Batch 300: Loss: 0.0148\n",
      "Epoch 1, 101.39 seconds\n",
      "Batch 100: Loss: 0.0042\n",
      "Batch 200: Loss: 0.0080\n",
      "Batch 300: Loss: 0.0115\n",
      "Epoch 2, 100.83 seconds\n",
      "Batch 100: Loss: 0.0030\n",
      "Batch 200: Loss: 0.0059\n",
      "Batch 300: Loss: 0.0086\n",
      "Epoch 3, 100.72 seconds\n",
      "Batch 100: Loss: 0.0026\n",
      "Batch 200: Loss: 0.0050\n",
      "Batch 300: Loss: 0.0073\n",
      "Epoch 4, 101.25 seconds\n",
      "CPU times: user 30min 46s, sys: 10min 35s, total: 41min 22s\n",
      "Wall time: 8min 26s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# 1 GPU - Main training loop: 27min 45s\n",
    "# 4 GPU - Main training loop: 8min 26s\n",
    "n_batch = 100\n",
    "for e in range(EPOCHS):\n",
    "    tick = time()\n",
    "    train_epoch(net, train_dataloader_synth, trainer, binary_cross_entropy, ctx)\n",
    "    nd.waitall()\n",
    "    print('Epoch {0}, {1:.2f} seconds'.format(e, time()-tick))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:py35]",
   "language": "python",
   "name": "conda-env-py35-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
